from hpo.casmo.bgpbt import BGPBT
import pandas as pd
import logging
from hpo.utils import get_reward_from_trajectory, is_large
import numpy as np
import ConfigSpace as CS
import shutil
import os
from copy import deepcopy
import random


class BGPBTVariant(BGPBT):
    """A variant of Casmo that distill after learning plateaus."""

    def __init__(self, env, log_dir,
                 max_timesteps: int = None,
                 pop_size: int = 4,
                 n_init: int = None,
                 verbose: bool = False,
                 ard=False,
                 t_ready: int = None,
                 n_distillation_timesteps: int = int(5e5),
                 quantile_fraction: float = .25,
                 seed: int = None,
                 use_reward_fraction: float = 0.,
                 existing_policy: str = 'resume',
                 backtrack: bool = False,
                 patience: int = 15,
                 distill_every: int = int(3e6),
                 arch_policy: str = 'static',
                 max_distillation: int = 2,
                 t_ready_end: int = None,
                 ):
        """Build the distilling variant on top of the base ``BGPBT`` search.

        ``env`` through ``existing_policy`` (minus the distillation-specific
        arguments below), plus ``backtrack`` and ``t_ready_end``, are forwarded to
        ``BGPBT.__init__``; ``schedule_t_ready`` is always passed as ``False``.

        Args:
            n_distillation_timesteps: Timestep budget handed to the distillation
                calls made in ``search_init``.
            patience: Number of consecutive non-improving iterations after which
                ``run`` schedules a distillation.
            distill_every: ``run`` also schedules a distillation once more than
                this many timesteps have passed since the last one.
            arch_policy: ``'static'`` uses 4 policy layers of width 32 and 5 value
                layers of width 256; ``'schedule'`` derives the depths from the
                number of distillations so far (capped at 5 and 6 layers).
            max_distillation: Upper bound on the number of distillations ``run``
                will schedule.
        """
        super(BGPBTVariant, self).__init__(
            env, log_dir, max_timesteps, pop_size, n_init, verbose, ard, t_ready,
            quantile_fraction, seed, use_reward_fraction, existing_policy,
            backtrack=backtrack,
            schedule_t_ready=False,
            t_ready_end=t_ready_end,
        )
        self.n_distillation_timesteps = n_distillation_timesteps
        self.distill_every = distill_every
        self.max_distillation = max_distillation
        self.patience = patience
        assert arch_policy in ('static', 'schedule')
        self.arch_policy = arch_policy

        # Network architecture: identical rule whether or not the run was resumed.
        if arch_policy == 'schedule':
            self.policy_net = [32] * min(5, self.n_distills + 2)
            self.value_net = [256] * min(6, self.n_distills + 3)
        elif arch_policy == 'static':
            self.policy_net = [32] * 4
            self.value_net = [256] * 5

        # Early-stopping bookkeeping.
        if not self.resumed:
            self.last_distill_timestep = 0
            self.n_fail = 0
            self.best_loss = float('inf')
            return

        # Resumed run: rebuild the bookkeeping from the rows of the current stage.
        stage_rows = self.df[self.df.n_distills == self.n_distills]
        self.last_distill_timestep = stage_rows[self.budget_type].min()
        self.best_loss = stage_rows.R.min()
        # Iterations elapsed since the most recent row that achieved the best loss.
        last_best_t = stage_rows[stage_rows.R == self.best_loss]['t'].iloc[-1]
        self.n_fail = self.df.t.max() - last_best_t
        logging.info(f'Resumed from existing: Last distill={self.last_distill_timestep}. n_fail={self.n_fail}. '
                     f'n_distills={self.n_distills}. Last timestep={self.df[self.budget_type].max()}. Current policy net = {self.policy_net},'
                     f'Current value net={self.value_net}.')

    def search_init(self, best_agents=None):
        """(Re-)initialise the population for the current distillation stage.

        Stage 0 (``self.n_distills == 0``): sample ``max(n_init, pop_size)``
        random configurations, train each of them for ``self.t_ready_start``
        timesteps, and hand the lowest-cost ones to the agents in ``self.pop``
        (their checkpoints are copied to the agents' checkpoint paths and the
        temporary init checkpoints are deleted).

        Later stages: every agent receives a configuration drawn at random from
        the current population and is distilled from a teacher drawn at random
        from ``best_agents``. Afterwards the student architectures replace
        ``self.policy_net`` / ``self.value_net``.

        In both cases one row per evaluated configuration is appended to
        ``self.df``.

        Args:
            best_agents: Keys of ``self.pop`` that may serve as teachers. Must be
                non-empty when ``self.n_distills > 0``; unused at stage 0.

        Raises:
            ValueError: If a distillation stage is requested without teachers.
        """
        if self.n_distills > 0 and (best_agents is None or len(best_agents) == 0):
            raise ValueError(
                f'best_agents must be a non-empty collection of agent ids when distilling '
                f'(n_distills={self.n_distills}); got {best_agents!r}.')

        ckpt_dir = f'{self.log_dir}/pb2_checkpoints'

        def init_ckpt_path(i):
            """Path of the temporary checkpoint of the i-th init configuration."""
            return os.path.join(ckpt_dir, f'{self.env.env_name}_InitConfig{i}_Stage{self.n_distills}.pt')

        def f(configs, num_timesteps=None, ckpt_paths=None, teacher_configs=None, replace_teacher=False,
              student_policy_net=None, student_value_net=None):
            """Train (stage 0) or distill (later stages) ``configs``; return the env's trajectories."""
            if student_policy_net is None: student_policy_net = self.policy_net
            if student_value_net is None: student_value_net = self.value_net
            if ckpt_paths is None:
                ckpt_paths = [init_ckpt_path(i) for i in range(len(configs))]
            if num_timesteps is None:
                num_timesteps = self.t_ready_start
            # NOTE(review): large models are counted over the current population but
            # normalised by the number of configs being run -- confirm this is intended.
            n_large_models = sum([is_large(c['config']) for c in self.pop.values()])
            max_parallel = self.env.max_parallel // 2 if n_large_models / len(
                configs) >= 0.25 else self.env.max_parallel
            logging.info(f'Running config={configs} with n_parallel={max_parallel}')

            if self.n_distills == 0:  # search init for the very beginning
                trajectories = self.env.train_batch(configs=configs, seeds=[self.seed] * len(configs),
                                                    nums_timesteps=[num_timesteps] * len(configs),
                                                    max_parallel=max_parallel,
                                                    policy_hidden_layer_sizes=self.policy_net,
                                                    v_hidden_layer_sizes=self.value_net,
                                                    checkpoint_paths=ckpt_paths)
            else:
                assert teacher_configs is not None and ckpt_paths is not None, \
                    'For distillation, teacher_configs and teacher_ckpts must be specified!'
                trajectories = self.env.distill_batch(teacher_configs=teacher_configs,
                                                      student_configs=configs,
                                                      seeds=[self.seed] * len(configs),
                                                      distill_nums_timesteps=[num_timesteps] * len(configs),
                                                      distill_total_num_timesteps=self.n_distillation_timesteps,
                                                      train_nums_timesteps=[0] * len(configs),
                                                      checkpoint_paths=ckpt_paths,
                                                      # the teacher config does not contain the arch info as it
                                                      # comes from the HPO search space only, so pass it explicitly
                                                      fixed_teacher_params={
                                                          'policy_hidden_layer_sizes': self.policy_net,
                                                          'v_hidden_layer_sizes': self.value_net
                                                      },
                                                      fixed_student_params={
                                                          'policy_hidden_layer_sizes': student_policy_net,
                                                          'v_hidden_layer_sizes': student_value_net,
                                                      },
                                                      max_parallel=max(1, max_parallel // 2),
                                                      replace_teacher=replace_teacher,)
            return trajectories

        init_size = max(self.n_init, self.pop_size)
        if self.n_distills == 0:
            init_configs = [self.env.config_space.sample_configuration() for _ in range(init_size)]
            trajectories = f(init_configs)
            # ``np.float`` was removed in NumPy 1.24; the builtin ``float`` is the same dtype.
            costs = [-get_reward_from_trajectory(np.array(t['y'], dtype=float), use_last_fraction=self.use_reward)
                     for t in trajectories]
            rl_rewards = [get_reward_from_trajectory(np.array(t['y'], dtype=float), 0) for t in trajectories]
            # init-config indices ordered from lowest to highest cost
            top_config_ids = np.argsort(costs).tolist()

            # The rank-th agent of the population receives the rank-th best init config.
            for rank, agent in enumerate(list(self.pop)):
                agent_ckpt = os.path.join(ckpt_dir, f'{self.env.env_name}_seed{self.env.seed}_Agent{agent}.pt')
                self.pop[agent] = {
                    'done': False,
                    'config': init_configs[top_config_ids[rank]],
                    'path': agent_ckpt,
                    'config_source': 'random',
                    'distill': True,  # signals that this should be distilled in the next iteration
                }
                shutil.copy(init_ckpt_path(top_config_ids[rank]), agent_ckpt)

            # delete the initialization checkpoints
            for config_id in top_config_ids:
                os.remove(init_ckpt_path(config_id))

            # Log every evaluated init config; those that did not make it into the
            # population are recorded with agent number -1 and no checkpoint path.
            rank_of_config = {config_id: rank for rank, config_id in enumerate(top_config_ids)}
            for i, config in enumerate(init_configs):
                scalar_steps = trajectories[i]['x'][-1] + self.last_distill_timestep
                agent_number = rank_of_config[i]
                if agent_number >= self.pop_size:
                    agent_number = -1
                path = self.pop[agent_number]['path'] if agent_number >= 0 else np.nan
                d = pd.DataFrame(columns=self.df.columns)
                # values are positional: they must follow the column order of self.df
                d.loc[0] = [agent_number, 1, scalar_steps, costs[i], rl_rewards[i], config.get_array().tolist(),
                            path, config,
                            'random', False, self.n_distills, np.nan, np.nan]
                self.df = pd.concat([self.df, d]).reset_index(drop=True)
                logging.info(
                    "\nAgent: {}, Timesteps: {}, Cost: {}\n".format(agent_number, scalar_steps, costs[i], ))
        else:
            # Under the 'schedule' policy the students get one extra layer (same width
            # as the last one) until the caps of 5 policy / 6 value layers are reached.
            if self.arch_policy == 'schedule' and len(self.policy_net) < 5 and len(self.value_net) < 6:
                student_policy_net = self.policy_net + [self.policy_net[-1]]
                student_value_net = self.value_net + [self.value_net[-1]]
            else:
                student_policy_net = self.policy_net
                student_value_net = self.value_net

            # Student configs: drawn uniformly (with replacement) from the current population.
            init_configs = [deepcopy(self.pop[random.choice(np.arange(self.pop_size))]['config'])
                            for _ in range(self.pop_size)]

            # Each student gets its own private copy of a randomly chosen teacher checkpoint.
            teacher_configs, teacher_ckpts = [], []
            for i in range(len(init_configs)):
                best_agent = np.random.choice(best_agents)
                teacher_ckpt = f"{self.pop[best_agent]['path']}_forDistillAgent{i}"
                shutil.copy(self.pop[best_agent]['path'], teacher_ckpt)
                teacher_configs.append(self.pop[best_agent]['config'])
                teacher_ckpts.append(teacher_ckpt)
            logging.info(f'run_init distill student config={init_configs}. teacher configs = {teacher_configs}')
            best_configs_for_distill = deepcopy(init_configs)
            distill_ckpts = deepcopy(teacher_ckpts)

            full_trajectories = f(best_configs_for_distill, num_timesteps=self.n_distillation_timesteps,
                                  teacher_configs=teacher_configs,
                                  student_value_net=student_value_net, student_policy_net=student_policy_net,
                                  ckpt_paths=teacher_ckpts, replace_teacher=True)  # replace teacher at the final iter.
            self.policy_net = student_policy_net
            self.value_net = student_value_net

            new_pop = deepcopy(self.pop)
            current_t = self.df.t.max() + 1
            for idx, agent in enumerate(list(new_pop)):
                logging.info(
                    f'Assigning {best_configs_for_distill[idx]} stored at {distill_ckpts[idx]} to Agent {agent}')
                shutil.copy(distill_ckpts[idx], self.pop[agent]['path'])

                new_config = deepcopy(best_configs_for_distill[idx])
                new_pop[agent] = {
                    'done': False,
                    'config': new_config,
                    'path': self.pop[agent]['path'],
                    'config_source': 'distilled',
                    'distill': False,  # freshly distilled: not flagged for distillation
                }
                # Record the rewards in self.df. The trajectories come back in the
                # order of the submitted configs, so they are addressed by position
                # (idx) exactly like the configs and checkpoints above.
                trajectory = full_trajectories[idx]
                rl_reward = get_reward_from_trajectory(np.array(trajectory['y']), 0)
                cost = -get_reward_from_trajectory(np.array(trajectory['y']),
                                                   use_last_fraction=self.use_reward)
                max_t = self.df[self.df.Agent == agent][self.budget_type].max()
                scalar_steps = trajectory['x'][-1] + max_t
                d = pd.DataFrame(columns=self.df.columns)
                # NOTE(review): the population entry is tagged 'distilled' while the
                # logged row keeps 'random' as its config source -- confirm intended.
                d.loc[0] = [agent, current_t, scalar_steps, cost, rl_reward,
                            new_config.get_array().tolist(), self.pop[agent]['path'], new_config, 'random', False,
                            self.n_distills, np.nan, np.nan]
                self.df = pd.concat([self.df, d]).reset_index(drop=True)
            self.pop = new_pop

            # the per-student teacher copies are no longer needed
            for teacher_ckpt in teacher_ckpts:
                os.remove(teacher_ckpt)

    def run(self, ):
        """Run the population-based search until every agent has used its budget.

        Each iteration:
          1. optionally distils the population (``search_init``) if the previous
             iteration scheduled it;
          2. trains every agent for ``t_ready`` timesteps and appends one row per
             unfinished agent to ``self.df``;
          3. adjusts the trust region and, when backtracking is enabled, saves the
             checkpoint of the best agent of the current stage;
          4. lets every agent exploit, logging an extra row when its config changed;
          5. writes the intermediate CSV and decides whether to distil next.

        Returns:
            ``self.df``, the full log of evaluated configurations.
        """
        # conf is the internal array representation (continuous/integer
        # hyperparameters scaled to [0, 1]); conf_ is the human-readable
        # ConfigSpace representation.
        all_done = False
        distill_at_this_step = False

        while not all_done:
            # halve the parallelism when at least a quarter of the population is large
            n_large_models = sum([is_large(c['config']) for c in self.pop.values()])
            max_parallel = self.env.max_parallel // 2 if n_large_models / len(
                self.pop) >= 0.25 else self.env.max_parallel

            if distill_at_this_step:
                n = max(int(self.pop_size * self.quantile_fraction), 1)
                logging.info('Modifying the nets at this iteration & distilling.')

                last_entries = self.df[
                    self.df['t'] == self.df.t.max()]  # index entire population based on last set of runs
                last_entries = last_entries.iloc[:self.pop_size]  # only want the original entries
                # sorted by descending cost, so the last n rows are the n lowest-cost agents
                ranked_last_entries = last_entries.sort_values(by=['R'], ignore_index=True,
                                                               ascending=False)
                best_agents = list(ranked_last_entries.iloc[-n:]['Agent'].values)

                self.search_init(best_agents=best_agents)
                distill_at_this_step = False

            logging.info(
                f'Max parallel for this iteration={max_parallel}. Last distillation step={self.last_distill_timestep}.')
            logging.info(
                f'Running config={[c["config"] for c in self.pop.values()]}. Policy_net={self.policy_net}, Value_net={self.value_net}')
            logging.info(f'Running config={[c["config"] for c in self.pop.values()]}')

            if self.t_ready_end is None or self.t_ready_end == self.t_ready_start:
                t_ready = self.t_ready_start
            else:
                # Linear schedule in the budget consumed so far. With an empty log
                # the max is NaN (and int(NaN) raises), so treat it as zero progress.
                consumed = 0 if self.df.empty else self.df[self.budget_type].max()
                t_ready = int(self.t_ready_start
                              + (self.t_ready_end - self.t_ready_start) / self.max_timesteps * consumed)

            results_values = self.env.train_batch(configs=[c['config'] for c in self.pop.values()],
                                                  seeds=[self.seed] * len(self.pop),
                                                  nums_timesteps=[t_ready] * len(self.pop),
                                                  checkpoint_paths=[c['path'] for c in self.pop.values()],
                                                  policy_hidden_layer_sizes=self.policy_net,
                                                  v_hidden_layer_sizes=self.value_net,
                                                  max_parallel=max_parallel)

            results_keys = list(self.pop.keys())
            results = dict(zip(results_keys, results_values))

            if self.df.t.empty:
                t = 1
            else:
                t = self.df.t.max() + 1

            for agent in self.pop.keys():
                if self.pop[agent]['done']:
                    logging.info(f'Skipping completed agent {agent}.')
                    continue

                # negative sign to convert the reward maximisation to a minimisation problem
                final_cost = -get_reward_from_trajectory(results[agent]['y'], self.use_reward, 0.)
                rl_reward = get_reward_from_trajectory(results[agent]['y'], 1, 0.)
                final_timestep = results[agent]['x'][-1]

                # accumulate the budget on top of what this agent has already consumed
                if self.df[self.df['Agent'] == agent].empty:
                    scalar_steps = final_timestep
                else:
                    scalar_steps = final_timestep + self.df[self.df['Agent'] == agent][self.budget_type].max()
                logging.info("\nAgent: {}, Timesteps: {}, Cost: {}\n".format(agent, scalar_steps, final_cost))

                conf_array = self.pop[agent]['config'].get_array().tolist()
                conf = self.pop[agent]['config']
                config_source = self.pop[agent]['config_source']
                d = pd.DataFrame(columns=self.df.columns)
                # values are positional: they must follow the column order of self.df
                d.loc[0] = [agent, t, scalar_steps, final_cost, rl_reward, conf_array,
                            self.pop[agent]['path'], conf, config_source, False, self.n_distills,
                            self.policy_net, self.value_net]
                self.df = pd.concat([self.df, d]).reset_index(drop=True)

                if self.df[self.df['Agent'] == agent][self.budget_type].max() >= self.max_timesteps:
                    self.pop[agent]['done'] = True

            # update the trust region based on the results of the agents from previous runs, before exploitation
            self.adjust_tr_length()

            best_loss = self.df[self.df.n_distills == self.n_distills]['R'].min()
            if self.backtrack and self.env.env_name not in ['dummy', 'synthetic']:
                if best_loss < self.best_cost:
                    # keep a copy of the best checkpoint seen in the current stage
                    self.best_cost = best_loss
                    overall_best_agent = \
                        self.df[(self.df['R'] == best_loss) & (self.df['n_distills'] == self.n_distills)].iloc[-1]
                    shutil.copy(overall_best_agent['path'], self.best_checkpoint_dir)

            # exploitation -- copy the weights and etc.
            for agent in self.pop.keys():
                old_conf = self.pop[agent]['config'].get_array().tolist()
                self.pop[agent], copied = self.exploit(agent, )
                new_conf = self.pop[agent]['config'].get_array().tolist()
                # Sum of *absolute* differences: a signed sum lets offsetting changes
                # cancel out and hides a genuine config change. NaN entries are
                # ignored, as before.
                conf_change = np.nansum(np.abs(np.array(old_conf) - np.array(new_conf)))
                if not np.isclose(0, conf_change):
                    logging.info("changing conf for agent: {}".format(agent))
                    # explicit copy so the edits below never touch (or warn about) self.df
                    new_row = self.df[(self.df['Agent'] == copied) & (self.df['t'] == self.df.t.max())].copy()
                    new_row['Agent'] = agent
                    logging.info(f"new row conf old: {new_row['conf']}")
                    logging.info(f"new row conf new: {[new_conf]}")
                    new_row['conf'] = [new_conf]
                    new_row['conf_'] = [CS.Configuration(self.env.config_space, vector=new_conf)]
                    self.df = pd.concat([self.df, new_row]).reset_index(drop=True)
                    logging.info(f"new config: {new_conf}")

            all_done = np.array([self.pop[agent]['done'] for agent in self.pop.keys()]).all()
            # save intermediate results
            self.df.to_csv(os.path.join(self.log_dir, f'stats_seed_{self.seed}_intermediate.csv'))
            # the earliest timestep of the current stage marks the last distillation
            self.last_distill_timestep = self.df[self.df.n_distills == self.n_distills][
                self.budget_type].min()
            t_max = self.df[self.budget_type].max()
            best_loss = self.df[self.df['n_distills'] == self.n_distills].R.min()
            # the failure counter resets only when the newest rows reach the stage's best loss
            if self.df[
                (self.df[self.budget_type] == t_max) & (self.df['n_distills'] == self.n_distills)].R.min() == best_loss:
                self.n_fail = 0
            else:
                self.n_fail += 1
            # schedule a distillation when progress has stalled for `patience` iterations
            # or more than `distill_every` timesteps have passed since the last one
            if self.n_distills < self.max_distillation and \
                    (self.n_fail >= self.patience or t_max - self.last_distill_timestep > self.distill_every):
                distill_at_this_step = True
                self.n_distills += 1
                self.n_fail = 0
                self.best_cost = float('inf')
                logging.info('Start distillation in the next iteration..')
            logging.info(f'n_fail: {self.n_fail}')
        return self.df
